{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EW6RGZaQQMZF"
      },
      "source": [
        "# Train an Adapter for NER"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pWm2YMkRQyKh"
      },
      "source": [
        "This notebook illustrates how you can train an adapter and head for a tagging task. We are using the CoNLL 2003 dataset to train the model on Named Entity Recognition (NER). Additionally, we will set and save the id2label dictionary so the model can easily be used by someone else. First, we need to install the 'adapters' and the 'datasets' package."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T12:16:27.919747100Z",
          "start_time": "2023-08-17T12:16:19.394914400Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hXGi9Hx1QJsw",
        "outputId": "989cbafa-df85-4751-cfd0-027df156f3c7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m204.3/204.3 kB\u001b[0m \u001b[31m3.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m17.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m22.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m44.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m49.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m519.3/519.3 kB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m8.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m7.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m251.2/251.2 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h"
          ]
        }
      ],
      "source": [
        "# Use the %pip magic (not !pip) so the packages are installed into the\n",
        "# environment of the running kernel rather than whatever `pip` is on the shell PATH.\n",
        "%pip install -Uq adapters\n",
        "%pip install -q datasets\n",
        "%pip install -q scikit-learn\n",
        "%pip install -Uq accelerate"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GFzFSmj9QkRc"
      },
      "source": [
        "Next, we instantiate the model, add a tagging head, and set the label mappings (`id2label` / `label2id`) so predictions can be translated back to tag names. We also add an adapter that will be trained on the NER task."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:35:35.326969200Z",
          "start_time": "2023-08-17T10:35:27.729402300Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 194,
          "referenced_widgets": [
            "7bfc8e75caff48dcbac9d265e2edbc4b",
            "384d5a3c01fd453baafca8e7cf1ddc78",
            "786aef51bda64759a053eca2917e62c3",
            "734eb7714f014053a675a86fa0900fca",
            "65bb83fa1c994335ba522109d99e8f0f",
            "e56a146788434d07ac7fddd6245adbc1",
            "ed619138f1004153979b1f4bc3581c35",
            "b3cce7e38f01492095d384c907bd2da2",
            "93325ff9d063497c83d14cfe549669c3",
            "c1ae34552bd844d5a80cc1bd10d0d97e",
            "ea71f0789c0f4b54bfee1a9ded4fac41",
            "44674a5161ba42588d5d4db56e0456bd",
            "4889e72f993d4ad7a3408b916e93107c",
            "1cbfc55859d84f5e99aa62e4b8113096",
            "76da26ee3c724234a690a249a7c487ec",
            "65e31ebdd9d446748cae0d80496210d6",
            "41a349546ec24a4fbc5308cec3bb51a9",
            "5a6a07766d7b4531adf0f2770b9d80e9",
            "ab94ecdc4152485caca91443c246b9c9",
            "128d595bf61d4f74823cf295a7076802",
            "af0dfa74bebf43c193262d9e62b24dca",
            "9969929b319749d3b28ee34b62bc37c5",
            "ce7bd37fe4644a6fa1f582716fa540a2",
            "940bc33655cd4e2fae99728ac80aef11",
            "58a4476e1c9b43b58fd71ffc3552ac5b",
            "f7ed848702f0416fa7e63f671ed474d0",
            "a5585e29591646c083355fd5a569a23a",
            "b2dae75082e84572a5deab791b7126c2",
            "76039cb548bc4094b830b39e9576c3cf",
            "a507d3ffd0454d469e474fa68d422b61",
            "e8b5fdf5de1047e89b498e3c60bcd7ac",
            "823e9944785949d29bffcb04899ceffc",
            "f50200a074bc45c7a9c7d2541d481c26",
            "75f015d6391c4e94b2ce88b99b3424b3",
            "5d73def7177843cca6bf34e5e436eff9",
            "814dc8eb4be54652b30db756eaecbdec",
            "25a82657a99248f08b51b0ef6283e4c7",
            "1e633a650c3d42c1817a51830f9b2e6b",
            "c7c5063efdb84755b32eb0039aa4fc2f",
            "8ff3bc2c677d46f78564c7b78918bc8d",
            "5731be7c90794b498ed4643d76fbc480",
            "7dce5f00ca074667b91e77c007d2a570",
            "d9f5d4e711ae49aca931e0c2c8076b0d",
            "8c5b1df836fe45128ebccec5b578e016",
            "f76f0efc922f4a5fb926ee6436c55e61",
            "a6bd14b364af4eadbb131e690437f017",
            "36ec80a69a6e4a848983d564ae1c8934",
            "18e4f495ecd84a7b944364952eea0f28",
            "be7ab9cbe68c4d2d896ee623846fb948",
            "7bf305f3bba14d35952b5eacc42cb18d",
            "a3f74a47268a4186910efa440864b133",
            "b7fdb9d52b3042199891a520d79b879f",
            "185e9f0ebb8040a5bff3b736f561f977",
            "ab6b913b1d3b465fa90ab9756dd99388",
            "a47717747a3349009a9dd8d66b2faa10"
          ]
        },
        "id": "x7OeDB9KQeDe",
        "outputId": "e1cf82ac-b96c-4f2d-cdf5-26bf46881333"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "7bfc8e75caff48dcbac9d265e2edbc4b",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "44674a5161ba42588d5d4db56e0456bd",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "ce7bd37fe4644a6fa1f582716fa540a2",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "75f015d6391c4e94b2ce88b99b3424b3",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "f76f0efc922f4a5fb926ee6436c55e61",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "['O', 'B-LOC', 'I-LOC', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-MISC', 'I-MISC']\n"
          ]
        }
      ],
      "source": [
        "from adapters import AutoAdapterModel\n",
        "from transformers import AutoTokenizer, AutoConfig\n",
        "from datasets import load_dataset\n",
        "from torch.utils.data import Dataset\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "from tqdm.notebook import tqdm\n",
        "from torch import nn\n",
        "# The labels for the NER task and the dictionaries that map ids to labels\n",
        "# and labels back to ids.\n",
        "labels = [\"O\", 'B-LOC', \"I-LOC\", \"B-PER\", \"I-PER\", \"B-ORG\", \"I-ORG\", \"B-MISC\", \"I-MISC\"]\n",
        "id2label = {id_: label for id_, label in enumerate(labels)}\n",
        "label2id = {label: id_ for id_, label in enumerate(labels)}\n",
        "\n",
        "model_name = \"bert-base-uncased\"\n",
        "# The config keyword is `num_labels`; a misspelled keyword is stored as an\n",
        "# unrelated attribute and silently ignored.\n",
        "config = AutoConfig.from_pretrained(model_name, num_labels=len(labels), id2label=id2label, label2id=label2id)\n",
        "# Pass the config to the model, otherwise the label mappings defined above are never used by it.\n",
        "model = AutoAdapterModel.from_pretrained(model_name, config=config)\n",
        "model.add_adapter(\"ner\")\n",
        "\n",
        "# The tagging head produces one score per label for every token.\n",
        "model.add_tagging_head(\"ner_head\", num_labels=len(labels), id2label=id2label)\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
        "\n",
        "print(model.get_labels())\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9UIzxvuRRyaI"
      },
      "source": [
        "BERT expects word-piece-tokenized text, but the tokens provided by the dataset are whole words. The `encode_labels` function maps the word-level labels of the CoNLL 2003 dataset onto the word-piece tokens. The `encode_data` function encodes the tokens as ids and adds the special tokens so the BERT model can handle the input."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:37:03.810633400Z",
          "start_time": "2023-08-17T10:37:03.779389500Z"
        },
        "id": "TyewMiteRHkv"
      },
      "outputs": [],
      "source": [
        "def encode_data(data, max_length=512):\n",
        "    \"\"\"Tokenize a batch of examples into fixed-length model inputs.\n",
        "\n",
        "    Args:\n",
        "        data: Batch with a \"tokens\" entry holding one list of words per example.\n",
        "        max_length: Length every sequence is padded or truncated to.\n",
        "\n",
        "    Returns:\n",
        "        The tokenizer output for the batch (special tokens included).\n",
        "    \"\"\"\n",
        "    sentences = [\" \".join(words) for words in data[\"tokens\"]]\n",
        "    # `padding=\"max_length\"` already requests fixed-length padding; the deprecated\n",
        "    # `pad_to_max_length` flag was redundant and has been dropped.\n",
        "    return tokenizer(sentences, padding=\"max_length\", max_length=max_length,\n",
        "                     truncation=True, add_special_tokens=True)\n",
        "\n",
        "\n",
        "def encode_labels(example):\n",
        "    r_tags = []\n",
        "    count = 0\n",
        "    token2word = []\n",
        "    for index, token in enumerate(tokenizer.tokenize(\" \".join(example[\"tokens\"]))):\n",
        "        if token.startswith(\"##\") or (token in example[\"tokens\"][index - count - 1].lower() and index - count - 1 >= 0):\n",
        "            # If the token is part of a larger token and not the first we need to differentiate.\n",
        "            # If it is a B (beginning) label the next one needs to be assigned an I (intermediate) label.\n",
        "            # Otherwise they can be labeled the same.\n",
        "            if r_tags[-1] % 2 == 1:\n",
        "                r_tags.append(r_tags[-1] + 1)\n",
        "            else:\n",
        "                r_tags.append(r_tags[-1])\n",
        "            count += 1\n",
        "        else:\n",
        "            r_tags.append(example[\"ner_tags\"][index - count])\n",
        "\n",
        "        token2word.append(index - count)\n",
        "    r_tags = torch.tensor(r_tags)\n",
        "    labels = {}\n",
        "    # Pad token to maximum length for using batches\n",
        "    labels[\"labels\"] = F.pad(r_tags, pad=(1, 511 - r_tags.shape[0]), mode='constant', value=0)\n",
        "    # Truncate if the document is too long\n",
        "    labels[\"labels\"] = labels[\"labels\"][:512]\n",
        "    return labels"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YEasTbG1TOXP"
      },
      "source": [
        "Next, we can load the dataset and use the previously defined functions to prepare the dataset for training. We then define two dataloaders: one for training and one for evaluation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:37:31.240106900Z",
          "start_time": "2023-08-17T10:37:07.227835100Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 433,
          "referenced_widgets": [
            "3d291c6197fe445097acf56c1111058a",
            "6ebc65db043640ab9a6703becffd70b0",
            "37574a56d8d54563bc135d73f73177d2",
            "8ba5a7ce31e14d88b9ec4f1bf77ab1b7",
            "5942c5c18d1d4a23860dd517c6d98aa8",
            "406febf96b5d43a786590da9fbd0bb64",
            "0e65d3f0f6d04830a7efd2b6486b74b6",
            "9ef45f7f80444d25ab9d4928d2e3cda9",
            "deff317b3a0946a1801c293d016aa8e6",
            "c5242f3c99874c49b40b10c6c69a56fe",
            "aa9b2cba91fa44ceb75636485317025d",
            "45e1489c33a84428bbe66f5fd02e5185",
            "b20dc0bd3d4c43588a98d12138f76ff6",
            "d9eb5930d3d44130b430f605c4bde901",
            "5621728678bf4cdeb0749f56d7ae0af8",
            "7aecc20ea2b943e99dce22113cfac39f",
            "984762abd094417aa2577c557c324d51",
            "8c7e6b5372f646ada8e1d53f125a008d",
            "b9145a370ada42459ed865a8f6769a02",
            "8ed8974bf3c54a40b22bd09a48dd20f9",
            "c648f36efcbb424bb62090a8564a8088",
            "68f6bdd4adde4a9891dacaaf3def87f8",
            "380de2077e854014ada5a8d3d20e672b",
            "b2393973ef3745c1a56c0aee688a55f1",
            "c6b853e934a34ac393ae87db5188af0c",
            "4e1260ae1da341e690068effac02adc3",
            "982597ac1c074c99bd451f2b6a8ff57a",
            "255dda130ab34a86babe7133e2f57975",
            "37846aed81694b1c8951abcbd4f3f2a1",
            "58b2f8e3a737454aa873f7b42f45f6b9",
            "7751033ac3ba429ab2b4cebc757b16fa",
            "b749a1e744a840ebb6a6b92eaf156ef2",
            "1d1fe528148b4166aff16406e5681f71",
            "9a9137dc96bb42068bed8909b726884c",
            "35d54fcf36f449fda6354c905cf71750",
            "a0b81e0aa2ec4726b3a31205c61eabcb",
            "7d3148d0914341b3a9570b8cbcdf09aa",
            "d88ba4f3b8ac461da4c79b37c5664fb4",
            "c580599986f44383a5caaf88874ce869",
            "5aee7c3cfc324ecfa42e1b34af62cb83",
            "20a907859a8e47f2af776476b7442f25",
            "29c5b70e74c94803bee6f721d905b3a6",
            "ea14b46bf8894bb3a39f111870bdc5f1",
            "fbfdb611441d4376a65b5e4c5f5b133b",
            "0c72aae7a2d64033b9035d1a31973238",
            "e8bc9381aaf543859d92dd80090dbc73",
            "6c6bf2152667496abeaed533d95d5a37",
            "e04bf0b951664c37b0a9dd02ac7d970f",
            "e2ef1328cec843019b36b52f0d99e7f7",
            "022b7d117a45410383a254e996e29d09",
            "b9351628430d421b9f28f75833f85980",
            "27374c964b8c4654b9dc49f52da603ea",
            "8c54ac8f2a274accbed8fc0344b88ec5",
            "ab0ec833f9ae41f08a3ae167516479cf",
            "8e7d525cca594cafaab43e55267d2fba",
            "d7861e4101434d08ba278c326eb656aa",
            "03bed012ac9d4260a7d44b8de3027656",
            "f438d56ba2654f70b195c54662c49958",
            "8cd9ec3c313f47148f95c6ad87bf9709",
            "60e22df85fab476caf7f982ebdbbd298",
            "0379776dc7504b869cdb07ad463fb4f6",
            "8729c0cdd52743c98cfe6cbbb15c77cb",
            "9e21e7327edd4890b918e251237906c1",
            "de124fa44c1d4eba9d4648da89890d8a",
            "4f4649ca703e429ea6aca8bc3e074a93",
            "eea2739a28fe4f2c98d7c7e46378bb61",
            "4266e0ec583e426d8346d7474eeff1e0",
            "038e22ef6109472ba57c94afa16c0352",
            "712c1aa999bc46d58c91f89dad5d4b22",
            "25b26a9a75c244b48b50a5cdda772628",
            "6765e82a5c4d4f3dbec74f6fcc3b253a",
            "91d5d2e610aa4d61b71264f2aedde3a0",
            "a0fe6ad97c3b401798a5920ecbe7c4ee",
            "b3da3f3770d64cc88b9ede2a6248b7a4",
            "e0e2ca1e0a8544fdb9a8661dfef3ba25",
            "c2c7cfa4e3694584b68240137c9ab197",
            "7b0670e1bc9f4d6a8f9add5bd2ba9337",
            "de6ab5017f4c4b5a99afc66bb720432a",
            "6249bd34982c4ea5b1b851865a751b07",
            "29841c4f40094203a139407ba96b49b5",
            "9c8f952de0c04a5ebb1d3204bdf39fcf",
            "aba8850dc055442ca0b07615acb1aaf2",
            "6b593bde0abf4d61aeb43e9a69ac0c98",
            "ce3965b2cd9c417d99d3b053e72f1f78",
            "dd32f0c8f11a41cc8daedfb1558e582f",
            "4cd9a5ccb2aa410fbfee1fe163c70f4d",
            "d9ef780933b04ea288284edf03d2af49",
            "6c024de190b745c7a34d439f383c2278",
            "15a246e075f24bbe94d44b1f690dc01e",
            "a5efd1e773c24824aa7e6cc73ed32918",
            "ce1a4f192e8140fc88f9df676152ada1",
            "d314fe66bdf148a19fede5de893859cb",
            "bb35c95a043048db98c841e9b4bfd155",
            "26ae54b78e9c49d88c8e1fb5a69bc5cf",
            "689a843d169148efa35118918892c24c",
            "24967c0ff753454e9c05a6dd6fa598ba",
            "53e719c80e9b4948a337abfac360d848",
            "cb66ec55fcaa40b198820084dd713486",
            "5beb381412f64e26bbc3dcc1f2aabcb8",
            "275b2cc3243a4b8489c651c9e087afc4",
            "2318256688f74510955e4d7b5f9a3e1c",
            "0f05a0e520004494a530b8aae0c90fb6",
            "f963273bc53a4debbce32e9cd84b6cc6",
            "564270e2f2fc4c4a8ac9b1cc9f858402",
            "1e809f84a8fa430fbff71dcdfbb8d3bb",
            "30bbf3a537a647ff9f38db00dcfd6d54",
            "45f22e63fb6344b0880a7f2520737798",
            "114bd3bcd6c043d3a7e2835f7fc0918c",
            "7381a7f11caf425580da4476c97580a7",
            "3d4b9e0747a845058a2e434c459c42ae",
            "db9ca7e373fd45f2af952a6c24057bf7",
            "c13910d4393e47d2bb7175ec00b5ce04",
            "112b974e2bea45ac850d91b9abaf8612",
            "9d5e0a582280484d819b899427729b62",
            "6c13ee99887042f696d97bf7ba40dffe",
            "427dffa286154aada14ad19e9cff6bf1",
            "344b454686324a9c89e62d7c408fad3b",
            "446def96703a45fc8294a4a37772126b",
            "675f2c3dc54c4bcbbfedbcbcbc47f7b3",
            "e3f7e1afc509424399741c9adc18ee77",
            "0222a4baab4b4b389db7b7f090265f3a",
            "f1fb84da993c49b4b8aaa2debb555972",
            "7b05ac104a794d6fa7b29ad33e886443",
            "e2fd6cf53125487cb0bd8f9475457196",
            "8274db8dc7c643b1a236740be9f9a791",
            "b651ef5040d242ff88cc11139ecedbe8",
            "fc41a261e4ed4ec1b0b2c58c4c033336",
            "4cc56d786ee9418ab9ad381919a11007",
            "bdf965b8db5a4cd8aaafbe1c694329b2",
            "1e05a34d374b49f6b7a949cfc2067db9",
            "d011ceb10f714bb097d1cb6fbac1b528",
            "c763edb0beb4433e8a4a310bb7179845",
            "660ccb668a374bd7b7e85bf58253689b",
            "1522a72fbf0e4bd6a0cdcc19e39dfa49",
            "cb7de70754264093a8297baf587606b3",
            "50d2c9703810460f84b13f7a647f7a8c",
            "001e73d1cc7a4de8ace5a17e208ac45f",
            "2ebaa8e7258a4da58fe9dc2af2865c09",
            "1111d16d002242e89df56dc90b1e3c97",
            "83780c7fcae74a349d22db2be1e64881",
            "7b7927b1be274434998726116084b78d",
            "23a874e05b5e4ba4a7e401fd0e0bfb0a",
            "351f72a0ac4f4df084ed046cd775e27d"
          ]
        },
        "id": "cxcq1ss4Qi1z",
        "outputId": "15100772-5f96-447a-e754-24f604d455f2"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "3d291c6197fe445097acf56c1111058a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading builder script:   0%|          | 0.00/9.57k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "45e1489c33a84428bbe66f5fd02e5185",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading metadata:   0%|          | 0.00/3.73k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "380de2077e854014ada5a8d3d20e672b",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading readme:   0%|          | 0.00/12.3k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "9a9137dc96bb42068bed8909b726884c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading data:   0%|          | 0.00/983k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "0c72aae7a2d64033b9035d1a31973238",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating train split:   0%|          | 0/14041 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "d7861e4101434d08ba278c326eb656aa",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating validation split:   0%|          | 0/3250 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "4266e0ec583e426d8346d7474eeff1e0",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating test split:   0%|          | 0/3453 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "de6ab5017f4c4b5a99afc66bb720432a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/14041 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "15a246e075f24bbe94d44b1f690dc01e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/3250 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "275b2cc3243a4b8489c651c9e087afc4",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/3453 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "db9ca7e373fd45f2af952a6c24057bf7",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/14041 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "f1fb84da993c49b4b8aaa2debb555972",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/3250 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "660ccb668a374bd7b7e85bf58253689b",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/3453 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Load CoNLL 2003, then add the aligned labels and the tokenized model inputs.\n",
        "dataset = load_dataset(\"conll2003\")\n",
        "dataset = dataset.map(encode_labels)\n",
        "dataset = dataset.map(encode_data, batched=True, batch_size=10)\n",
        "\n",
        "# Only expose the columns the model needs, as torch tensors.\n",
        "dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])\n",
        "\n",
        "# Shuffle the training examples so each epoch sees them in a different order;\n",
        "# the evaluation data keeps its original order.\n",
        "dataloader = torch.utils.data.DataLoader(dataset[\"train\"], shuffle=True)\n",
        "evaluate_dataloader = torch.utils.data.DataLoader(dataset[\"test\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M-47JZsvTqyC"
      },
      "source": [
        "Before we can start training the model, we need to define some training parameters. We check if a GPU is available for training and set our device accordingly. Then we can tell the model which adapter we want to train with `model.train_adapter(\"<adapter_name>\")`. As the loss function, we use Cross Entropy Loss. Finally, we need to define an optimizer for training, specifying its parameters and learning rate."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:37:58.588038600Z",
          "start_time": "2023-08-17T10:37:58.367488200Z"
        },
        "id": "AtZOJUC2RNbu"
      },
      "outputs": [],
      "source": [
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "model.to(device)\n",
        "model.set_active_adapters(\"ner\")\n",
        "model.train_adapter(\"ner\")\n",
        "\n",
        "loss_function = nn.CrossEntropyLoss()\n",
        "no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
        "# Only hand parameters that still require gradients to the optimizer, so it does\n",
        "# not have to track the weights that are not being trained.\n",
        "trainable_parameters = [(n, p) for n, p in model.named_parameters() if p.requires_grad]\n",
        "optimizer_grouped_parameters = [\n",
        "                {\n",
        "                    # Weights that receive weight decay.\n",
        "                    \"params\": [p for n, p in trainable_parameters if not any(nd in n for nd in no_decay)],\n",
        "                    \"weight_decay\": 1e-5,\n",
        "                },\n",
        "                {\n",
        "                    # Biases and LayerNorm weights are excluded from weight decay.\n",
        "                    \"params\": [p for n, p in trainable_parameters if any(nd in n for nd in no_decay)],\n",
        "                    \"weight_decay\": 0.0,\n",
        "                },\n",
        "            ]\n",
        "optimizer = torch.optim.AdamW(params=optimizer_grouped_parameters, lr=1e-4)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PkVsDH58Ulr2"
      },
      "source": [
        "Then we can start the training. In this case, we trained the model for 2 epochs. Feel free to play around with the hyperparameters like the number of epochs, the learning rate, ... But keep in mind that adapters often need a few more training epochs than full finetuning."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:38:42.956821700Z",
          "start_time": "2023-08-17T10:38:09.867520300Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 188,
          "referenced_widgets": [
            "94ed9cc4694645569427c490ba3013d0",
            "b2d089f21ac748f3ab25b1f42208e2cd",
            "dad0ed9c6f994d7caf50509f782fd21a",
            "31cf072a930a441ebac616d7c541f851",
            "aac0d2091a254cc3919b40fb6a286787",
            "184bf314a3b54fa9b077ae43708b8812",
            "48d0f198e18f4d6682846fa922853968",
            "0e2d1494809943d1a6eb38ca5a8e84d2",
            "8f4e2d5c779d4e71ae7dc30586669d09",
            "961cbc92e9e24656b653328f8d971bb5",
            "7b424f3e53c14605853f2c6fe8c863a3",
            "1795946250de4e6d9ddf9a0104c39125",
            "fd6996d1026e4783b882dcf0bbbcb526",
            "a0c58f0bf74a44e2ac95f2ef518cc679",
            "021505e21e814600bebb598d75aaac0f",
            "733ad5da593a4731b86a103bd6587446",
            "4b39608d1a8d48259389dd209620c0a4",
            "53f5321a5b1c40b780dea0259b7b5bd1",
            "14db68382cef4004adfc1d1857e9502c",
            "ed68de8db0984d99a30908cb96e175be",
            "6b2dbbba09ac4c0ebf0d7bcc9d75fa61",
            "0b7ac7c98323495a873761424b48fd2c"
          ]
        },
        "id": "g6R5Ge4RRSMo",
        "outputId": "bf06c3d2-1b7a-454e-cb28-8def4fcfe20e"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "94ed9cc4694645569427c490ba3013d0",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "  0%|          | 0/14041 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "loss: 3.087712049484253\n",
            "loss: 0.0004019373736809939\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "1795946250de4e6d9ddf9a0104c39125",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "  0%|          | 0/14041 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "loss: 0.002754304325208068\n",
            "loss: 0.00028402122552506626\n"
          ]
        }
      ],
      "source": [
        "for epoch in range(2):\n",
        "    for i, batch in enumerate(tqdm(dataloader)):\n",
        "\n",
        "        batch = {k: v.to(device) for k, v in batch.items()}\n",
        "\n",
        "        outputs = model(batch[\"input_ids\"] )\n",
        "        # We need to reformat the tensors for the loss function.\n",
        "        # They need to have the shape (N, C) and (N,) where N is the number\n",
        "        # of tokens and C the number of classes.\n",
        "        predictions = torch.flatten(outputs[0], 0, 1)\n",
        "        expected = torch.flatten(batch[\"labels\"].long(), 0, 1)\n",
        "\n",
        "        loss = loss_function(predictions, expected)\n",
        "        loss.backward()\n",
        "\n",
        "        optimizer.step()\n",
        "        optimizer.zero_grad()\n",
        "        if i % 10000 == 0:\n",
        "            print(f\"loss: {loss}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A9JoeQeTVXzO"
      },
      "source": [
        "Then we can save the adapter and head we trained with `model.save_adapter` and `model.save_head` for future use."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:38:46.554573900Z",
          "start_time": "2023-08-17T10:38:46.523329900Z"
        },
        "id": "n8XOhPzCRXlc"
      },
      "outputs": [],
      "source": [
        "model.save_adapter('adapter/', 'ner')\n",
        "model.save_head(\"head/\", \"ner_head\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-0K2QlaAC8Tt"
      },
      "source": [
        "For evaluating our trained adapter, we use a confusion matrix to display how often a token with label x was classified as a class with label y. We can see that the predictions are in most cases correct. From the confusion matrix, we can additionally see which labels were wrongly predicted."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T10:50:40.781717500Z",
          "start_time": "2023-08-17T10:42:00.855101Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 205,
          "referenced_widgets": [
            "42d34c55d5c541e18fb8ee0200442942",
            "94f0840b29064caaaebdcab95a1af5bf",
            "4fea84ddea67434bbe74d61f99ad4138",
            "4b68bf321a2a4a2995c94d71f21bfb05",
            "285320fc758347d2b70f2cdadb783616",
            "84d4167e5551434e984e4b16994a911d",
            "63a933396b1f41979f9295ece6a5154c",
            "fe40b664e00349c8a537edb933558d82",
            "bcb9a39e39c94e47a5741426fffe16fe",
            "5c534559d50947c5b7dd911b8b4b7f5b",
            "15b7a8deab834bb8b75b18a5865bc98c"
          ]
        },
        "id": "hNsyUkmgJ6FV",
        "outputId": "87c4cc1a-7302-4d84-de49-6805b1630a82"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "42d34c55d5c541e18fb8ee0200442942",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "  0%|          | 0/3453 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[[1754582     114      96     118     184     127      71     126     140]\n",
            " [     55    1449      77      25       2       6       0       1       0]\n",
            " [    111       3    3394       2      39       0       4       0       5]\n",
            " [     90      18       0    1389      39      91       0      33       0]\n",
            " [    117       1      26       8    1659       1     105       0      28]\n",
            " [    100       5       0      67       8    1435      25      23       1]\n",
            " [     63       0       5       0      45       2     610       0       1]\n",
            " [    109       7       0      33       2      19       0     504      26]\n",
            " [     85       1      16       6      47       2      39      20     294]]\n"
          ]
        }
      ],
      "source": [
        "from sklearn.metrics import confusion_matrix\n",
        "model.to(device)\n",
        "model.eval()\n",
        "predictions_list = []\n",
        "expected_list = []\n",
        "for i, batch in enumerate(tqdm(evaluate_dataloader)):\n",
        "    batch = {k: v.to(device) for k, v in batch.items()}\n",
        "    outputs = model(batch[\"input_ids\"], adapter_names=['ner'])\n",
        "    predictions = torch.argmax(outputs[0], 2)\n",
        "    expected = batch[\"labels\"].float()\n",
        "\n",
        "    predictions_list.append(predictions)\n",
        "    expected_list.append(expected)\n",
        "\n",
        "print(confusion_matrix(torch.flatten(torch.cat(expected_list)).cpu().numpy(),\n",
        "                 torch.flatten(torch.cat(predictions_list)).cpu().numpy()))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
